"""
@author: chenzhenhua
@project: jf_fashion
@file: mnist.py
@time: 2021/8/2 0002 11:15
@desc: Smoke test that trains the MNIST network with SGD and then evaluates it.
"""

import unittest

import torch
from torch.optim import SGD

from jf_fashion.pytorch.mnist import Mnist
from jf_fashion.pytorch.mnist_net import Net


class TestCase(unittest.TestCase):
    """End-to-end checks for the MNIST training pipeline."""

    def test_mnist_train(self):
        """Train ``Net`` through the project's ``Mnist`` helper, then evaluate it.

        This is a smoke test: it makes no assertions and passes as long as
        training and evaluation complete without raising.
        """
        # Seed torch's RNG before the network is built so that weight
        # initialisation (and any later torch randomness) is reproducible.
        torch.manual_seed(1)

        # Relative path; presumably resolved against the current working
        # directory by Mnist -- TODO confirm.
        data_dir = 'data/'
        train_batch, test_batch = 64, 1000
        epochs, log_every = 3, 10

        model = Net()
        opt = SGD(model.parameters(), lr=0.01, momentum=0.5)

        # Arguments are passed positionally, so their order must match the
        # Mnist.train / Mnist.test signatures.
        mnist = Mnist(data_dir)
        mnist.train(model, opt, epochs, log_every, train_batch)
        mnist.test(model, test_batch)


if __name__ == "__main__":
    # Allow running this file directly: load and run the tests defined in
    # this module via unittest's command-line entry point.
    unittest.main(module="__main__")
